{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pickle\n",
    "from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,\n",
    "                              TensorDataset)\n",
    "from torch.nn import CrossEntropyLoss, MSELoss\n",
    "\n",
    "from tqdm import tqdm_notebook, trange\n",
    "import os\n",
    "from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification\n",
    "from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule\n",
    "\n",
    "from multiprocessing import Pool, cpu_count\n",
    "from tools import *\n",
    "import convert_examples_to_features\n",
    "\n",
    "# OPTIONAL: if you want to have more information on what's happening, activate the logger as follows\n",
    "import logging\n",
    "logging.basicConfig(level=logging.INFO)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The input data dir. Should contain the .tsv files (or other data files) for the task.\n",
    "DATA_DIR = \"data/\"\n",
    "\n",
    "# Bert pre-trained model selected in the list: bert-base-uncased, \n",
    "# bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased,\n",
    "# bert-base-multilingual-cased, bert-base-chinese.\n",
    "BERT_MODEL = 'bert-base-cased'\n",
    "\n",
    "# The name of the task to train.I'm going to name this 'yelp'.\n",
    "TASK_NAME = 'yelp'\n",
    "\n",
    "# The output directory where the fine-tuned model and checkpoints will be written.\n",
    "OUTPUT_DIR = f'outputs/{TASK_NAME}/'\n",
    "\n",
    "# The directory where the evaluation reports will be written to.\n",
    "REPORTS_DIR = f'reports/{TASK_NAME}_evaluation_report/'\n",
    "\n",
    "# This is where BERT will look for pre-trained models to load parameters from.\n",
    "CACHE_DIR = 'cache/'\n",
    "\n",
    "# The maximum total input sequence length after WordPiece tokenization.\n",
    "# Sequences longer than this will be truncated, and sequences shorter than this will be padded.\n",
    "MAX_SEQ_LENGTH = 128\n",
    "\n",
    "TRAIN_BATCH_SIZE = 24\n",
    "EVAL_BATCH_SIZE = 8\n",
    "LEARNING_RATE = 2e-5\n",
    "NUM_TRAIN_EPOCHS = 1\n",
    "RANDOM_SEED = 42\n",
    "GRADIENT_ACCUMULATION_STEPS = 1\n",
    "WARMUP_PROPORTION = 0.1\n",
    "OUTPUT_MODE = 'classification'\n",
    "\n",
    "CONFIG_NAME = \"config.json\"\n",
    "WEIGHTS_NAME = \"pytorch_model.bin\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists(REPORTS_DIR) and os.listdir(REPORTS_DIR):\n",
    "        REPORTS_DIR += f'/report_{len(os.listdir(REPORTS_DIR))}'\n",
    "        os.makedirs(REPORTS_DIR)\n",
    "if not os.path.exists(REPORTS_DIR):\n",
    "    os.makedirs(REPORTS_DIR)\n",
    "    REPORTS_DIR += f'/report_{len(os.listdir(REPORTS_DIR))}'\n",
    "    os.makedirs(REPORTS_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "if os.path.exists(OUTPUT_DIR) and os.listdir(OUTPUT_DIR):\n",
    "        raise ValueError(\"Output directory ({}) already exists and is not empty.\".format(OUTPUT_DIR))\n",
    "if not os.path.exists(OUTPUT_DIR):\n",
    "    os.makedirs(OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "processor = BinaryClassificationProcessor()\n",
    "train_examples = processor.get_train_examples(DATA_DIR)\n",
    "train_examples_len = len(train_examples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_list = processor.get_labels() # [0, 1] for binary classification\n",
    "num_labels = len(label_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_train_optimization_steps = int(\n",
    "    train_examples_len / TRAIN_BATCH_SIZE / GRADIENT_ACCUMULATION_STEPS) * NUM_TRAIN_EPOCHS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at C:\\Users\\chatu\\.pytorch_pretrained_bert\\5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n"
     ]
    }
   ],
   "source": [
    "# Load pre-trained model tokenizer (vocabulary)\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "label_map = {label: i for i, label in enumerate(label_list)}\n",
    "train_examples_for_processing = [(example, label_map, MAX_SEQ_LENGTH, tokenizer, OUTPUT_MODE) for example in train_examples]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "process_count = cpu_count() - 1\n",
    "if __name__ ==  '__main__':\n",
    "    print(f'Preparing to convert {train_examples_len} examples..')\n",
    "    print(f'Spawning {process_count} processes..')\n",
    "    with Pool(process_count) as p:\n",
    "        train_features = list(tqdm_notebook(p.imap(convert_examples_to_features.convert_example_to_feature, train_examples_for_processing), total=train_examples_len))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# with open(DATA_DIR + \"train_features.pkl\", \"wb\") as f:\n",
    "#     pickle.dump(train_features, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# with open(DATA_DIR + \"train_features.pkl\", \"rb\") as f:\n",
    "#     train_features = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:pytorch_pretrained_bert.modeling:loading archive file cache/cased_base_bert_pytorch.tar.gz\n",
      "INFO:pytorch_pretrained_bert.modeling:extracting archive file cache/cased_base_bert_pytorch.tar.gz to temp dir C:\\Users\\chatu\\AppData\\Local\\Temp\\tmpysx78xfv\n",
      "INFO:pytorch_pretrained_bert.modeling:Model config {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"vocab_size\": 28996\n",
      "}\n",
      "\n",
      "INFO:pytorch_pretrained_bert.modeling:Weights of BertForSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']\n",
      "INFO:pytorch_pretrained_bert.modeling:Weights from pretrained model not used in BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']\n"
     ]
    }
   ],
   "source": [
    "# Load pre-trained model (weights)\n",
    "# model = BertForSequenceClassification.from_pretrained(CACHE_DIR + 'cased_base_bert_pytorch.tar.gz', cache_dir=CACHE_DIR, num_labels=num_labels)\n",
    "model = BertForSequenceClassification.from_pretrained(BERT_MODEL, cache_dir=CACHE_DIR, num_labels=num_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BertForSequenceClassification(\n",
       "  (bert): BertModel(\n",
       "    (embeddings): BertEmbeddings(\n",
       "      (word_embeddings): Embedding(28996, 768, padding_idx=0)\n",
       "      (position_embeddings): Embedding(512, 768)\n",
       "      (token_type_embeddings): Embedding(2, 768)\n",
       "      (LayerNorm): BertLayerNorm()\n",
       "      (dropout): Dropout(p=0.1)\n",
       "    )\n",
       "    (encoder): BertEncoder(\n",
       "      (layer): ModuleList(\n",
       "        (0): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (1): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (2): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (3): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (4): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (5): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (6): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (7): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (8): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (9): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (10): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "        (11): BertLayer(\n",
       "          (attention): BertAttention(\n",
       "            (self): BertSelfAttention(\n",
       "              (query): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (key): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (value): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "            (output): BertSelfOutput(\n",
       "              (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (LayerNorm): BertLayerNorm()\n",
       "              (dropout): Dropout(p=0.1)\n",
       "            )\n",
       "          )\n",
       "          (intermediate): BertIntermediate(\n",
       "            (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
       "          )\n",
       "          (output): BertOutput(\n",
       "            (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (LayerNorm): BertLayerNorm()\n",
       "            (dropout): Dropout(p=0.1)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (pooler): BertPooler(\n",
       "      (dense): Linear(in_features=768, out_features=768, bias=True)\n",
       "      (activation): Tanh()\n",
       "    )\n",
       "  )\n",
       "  (dropout): Dropout(p=0.1)\n",
       "  (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "param_optimizer = list(model.named_parameters())\n",
    "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "optimizer_grouped_parameters = [\n",
    "    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},\n",
    "    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = BertAdam(optimizer_grouped_parameters,\n",
    "                     lr=LEARNING_RATE,\n",
    "                     warmup=WARMUP_PROPORTION,\n",
    "                     t_total=num_train_optimization_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "global_step = 0\n",
    "nb_tr_steps = 0\n",
    "tr_loss = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:root:***** Running training *****\n",
      "INFO:root:  Num examples = 560000\n",
      "INFO:root:  Batch size = 24\n",
      "INFO:root:  Num steps = 23333\n"
     ]
    }
   ],
   "source": [
    "logger.info(\"***** Running training *****\")\n",
    "logger.info(\"  Num examples = %d\", train_examples_len)\n",
    "logger.info(\"  Batch size = %d\", TRAIN_BATCH_SIZE)\n",
    "logger.info(\"  Num steps = %d\", num_train_optimization_steps)\n",
    "all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)\n",
    "all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)\n",
    "all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)\n",
    "\n",
    "if OUTPUT_MODE == \"classification\":\n",
    "    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)\n",
    "elif OUTPUT_MODE == \"regression\":\n",
    "    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)\n",
    "train_sampler = RandomSampler(train_data)\n",
    "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=TRAIN_BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch:   0%|                                                                                     | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1d4f8c0f7064bffba474db1329032dc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, description='Iteration', max=23334, style=ProgressStyle(description_width=…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.053522"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 100%|████████████████████████████████████████████████████████████████████████| 1/1 [3:09:55<00:00, 11395.61s/it]\n"
     ]
    }
   ],
   "source": [
    "model.train()\n",
    "for _ in trange(int(NUM_TRAIN_EPOCHS), desc=\"Epoch\"):\n",
    "    tr_loss = 0\n",
    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
    "    for step, batch in enumerate(tqdm_notebook(train_dataloader, desc=\"Iteration\")):\n",
    "        batch = tuple(t.to(device) for t in batch)\n",
    "        input_ids, input_mask, segment_ids, label_ids = batch\n",
    "\n",
    "        logits = model(input_ids, segment_ids, input_mask, labels=None)\n",
    "\n",
    "        if OUTPUT_MODE == \"classification\":\n",
    "            loss_fct = CrossEntropyLoss()\n",
    "            loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))\n",
    "        elif OUTPUT_MODE == \"regression\":\n",
    "            loss_fct = MSELoss()\n",
    "            loss = loss_fct(logits.view(-1), label_ids.view(-1))\n",
    "\n",
    "        if GRADIENT_ACCUMULATION_STEPS > 1:\n",
    "            loss = loss / GRADIENT_ACCUMULATION_STEPS\n",
    "\n",
    "        loss.backward()\n",
    "        print(\"\\r%f\" % loss, end='')\n",
    "        \n",
    "        tr_loss += loss.item()\n",
    "        nb_tr_examples += input_ids.size(0)\n",
    "        nb_tr_steps += 1\n",
    "        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "            global_step += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'outputs/yelp/vocab.txt'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self\n",
    "\n",
    "# If we save using the predefined names, we can load using `from_pretrained`\n",
    "output_model_file = os.path.join(OUTPUT_DIR, WEIGHTS_NAME)\n",
    "output_config_file = os.path.join(OUTPUT_DIR, CONFIG_NAME)\n",
    "\n",
    "torch.save(model_to_save.state_dict(), output_model_file)\n",
    "model_to_save.config.to_json_file(output_config_file)\n",
    "tokenizer.save_vocabulary(OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
